import torch
from torch import nn
from d2l import torch as d2l

def main():
    """Demonstrate ``nn.MaxPool2d`` window size, padding, stride and multi-channel input.

    Prints the 4x4 input, the pooled results, and the channel-concatenated
    input, in the same order as the original script.

    Returns:
        A tuple ``(out_default, out_padded, out_multi)``:
          - ``out_default``: 3x3 max pooling with default padding/stride.
          - ``out_padded``: 3x3 max pooling with ``padding=1, stride=2``.
          - ``out_multi``: the same padded pooling applied to the 2-channel
            input built with ``torch.cat``.
    """
    # Single-sample, single-channel 4x4 input: shape (1, 1, 4, 4).
    x = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
    print(x)

    # 3x3 pooling window; when stride is not given it defaults to
    # kernel_size (here 3), and padding defaults to 0.
    pool_default = nn.MaxPool2d(3)
    out_default = pool_default(x)
    print(out_default)

    # Explicitly override the defaults: pad by 1 on each side, stride 2.
    pool_padded = nn.MaxPool2d(3, padding=1, stride=2)
    out_padded = pool_padded(x)
    print(out_padded)

    print("torch.cat")
    # Concatenate X and X + 1 along dim 1 -> shape (1, 2, 4, 4).
    x = torch.cat((x, x + 1), 1)
    print(x)

    # MaxPool2d has no learnable state, so the same layer is reused.
    # Pooling runs per channel, so the output keeps 2 channels.
    out_multi = pool_padded(x)
    print(out_multi)

    return out_default, out_padded, out_multi


if __name__ == '__main__':
    main()